{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction to Planning for Self Driving Vehicles\n",
    "\n",
    "In this notebook you are going to train your own ML policy to fully control an SDV. You will train your model using the Woven by Toyota Prediction Dataset and [L5Kit](https://github.com/woven-planet/l5kit).\n",
    "\n",
    "**Before starting, please download the [Woven by Toyota Prediction Dataset 2020](https://woven.toyota/en/prediction-dataset) and follow [the instructions](https://github.com/woven-planet/l5kit#download-the-datasets) to correctly organise it.**\n",
    "\n",
    "The policy will be a deep neural network (DNN) which will be invoked by the SDV to obtain the next command to execute.\n",
    "\n",
    "More in details, you will be working with a CNN architecture based on ResNet50.\n",
    "\n",
    "![model](../../docs/images/planning/model.svg)\n",
    "\n",
    "\n",
    "#### Inputs\n",
    "The network will receive a Bird's-Eye-View (BEV) representation of the scene surrounding the SDV as the only input. This has been rasterised in a fixed grid image to comply with the CNN input. L5Kit is shipped with various rasterisers. Each one of them captures different aspects of the scene (e.g. lanes or satellite view).\n",
    "\n",
    "This input representation is very similar to the one used in the [prediction competition](https://www.kaggle.com/c/lyft-motion-prediction-autonomous-vehicles/overview). Please refer to our [competition baseline notebook](../agent_motion_prediction/agent_motion_prediction.ipynb) and our [data format notebook](../visualisation/visualise_data.ipynb) if you want to learn more about it.\n",
    "\n",
    "#### Outputs\n",
    "The network outputs the driving signals required to fully control the SDV. In particular, this is a trajectory of XY and yaw displacements which can be used to move and steer the vehicle.\n",
    "\n",
    "After enough training, your model will be able to drive an agent along a specific route. Among others, it will do lane-following while respecting traffic lights.\n",
    "\n",
    "Let's now focus on how to train this model on the available data.\n",
    "\n",
    "### Training using imitation learning\n",
    "The model is trained using a technique called *imitation learning*. We feed examples of expert driving experiences to the model and expect it to take the same actions as the driver did in those episodes. Imitation Learning is a subfield of supervised learning, in which a model tries to learn a function f: X -> Y describing given input / output pairs - one prominent example of this is image classification.\n",
    "\n",
    "This is also the same concept we use in our [motion prediction notebook](../agent_motion_prediction/agent_motion_prediction.ipynb), so feel free to check that out too.\n",
    "\n",
    "##### Imitation learning limitations\n",
    "\n",
    "Imitation Learning is powerful, but it has a strong limitation. It's not trivial for a trained model to generalise well on out-of-distribution data.\n",
    "\n",
    "After training the model, we would like it to take full control and drive the AV in an autoregressive fashion (i.e. by following its own predictions).\n",
    "\n",
    "During evaluation it's very easy for errors to compound and make the AV drift away from the original distribution. In fact, during training our model has seen only good examples of driving. In particular, this means **almost perfect midlane following**. However, even a small constant displacement during evaluation can accumulate enough error to lead the AV completely out of its distribution in a matter of seconds.\n",
    "\n",
    "![drifting](../../docs/images/planning/drifting.svg)\n",
    "\n",
    "This is a well known issue in SDV control and simulation discussed, among others, in [this article](https://ri.cmu.edu/pub_files/2010/5/Ross-AIStats10-paper.pdf).\n",
    "\n",
    "# Adding perturbations to the mix\n",
    "\n",
    "One of the simplest techniques to ensure a good generalisation is **data augmentation**, which exposes the network to different versions of the input and helps it to generalise better to out-of-distribution situations.\n",
    "\n",
    "In our setting, we want to ensure that **our model can recover if it ends up slightly off the midlane it is following**.\n",
    "\n",
    "Following [the noteworthy approach from Waymo](https://arxiv.org/pdf/1812.03079.pdf), we can enrich the training set with **online trajectory perturbations**. These perturbations are kinematically feasible and affect both starting angle and position. A new ground truth trajectory is then generated to link this new starting point with the original trajectory end point. These starting point will be slightly rotated and off the original midlane, and the new trajectory will teach the model how to recover from this situation.\n",
    "\n",
    "![perturbation](../../docs/images/planning/perturb.svg)\n",
    "\n",
    "\n",
    "In the following cell, we load the training data and leverage L5Kit to add these perturbations to our training set.\n",
    "We also plot the same example with and without perturbation. During training, our model will see also those examples and learn how to recover from positional and angular offsets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tempfile import gettempdir\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "from l5kit.configs import load_config_data\n",
    "from l5kit.data import LocalDataManager, ChunkedDataset\n",
    "from l5kit.dataset import EgoDataset\n",
    "from l5kit.rasterization import build_rasterizer\n",
    "from l5kit.geometry import transform_points\n",
    "from l5kit.visualization import TARGET_POINTS_COLOR, draw_trajectory\n",
    "from l5kit.planning.rasterized.model import RasterizedPlanningModel\n",
    "from l5kit.kinematic import AckermanPerturbation\n",
    "from l5kit.random import GaussianRandomGenerator\n",
    "\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare data path and load cfg\n",
    "\n",
    "By setting the `L5KIT_DATA_FOLDER` variable, we can point the script to the folder where the data lies.\n",
    "\n",
    "Then, we load our config file with relative paths and other configurations (rasteriser, training params...)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set env variable for data\n",
    "os.environ[\"L5KIT_DATA_FOLDER\"] = \"/tmp/l5kit_data\"\n",
    "dm = LocalDataManager(None)\n",
    "# get config\n",
    "cfg = load_config_data(\"./config.yaml\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "perturb_prob = cfg[\"train_data_loader\"][\"perturb_probability\"]\n",
    "\n",
    "# rasterisation and perturbation\n",
    "rasterizer = build_rasterizer(cfg, dm)\n",
    "mean = np.array([0.0, 0.0, 0.0])  # lateral, longitudinal and angular\n",
    "std = np.array([0.5, 1.5, np.pi / 6])\n",
    "perturbation = AckermanPerturbation(\n",
    "        random_offset_generator=GaussianRandomGenerator(mean=mean, std=std), perturb_prob=perturb_prob)\n",
    "\n",
    "# ===== INIT DATASET\n",
    "train_zarr = ChunkedDataset(dm.require(cfg[\"train_data_loader\"][\"key\"])).open()\n",
    "train_dataset = EgoDataset(cfg, train_zarr, rasterizer, perturbation)\n",
    "\n",
    "# plot same example with and without perturbation\n",
    "for perturbation_value in [1, 0]:\n",
    "    perturbation.perturb_prob = perturbation_value\n",
    "\n",
    "    data_ego = train_dataset[0]\n",
    "    im_ego = rasterizer.to_rgb(data_ego[\"image\"].transpose(1, 2, 0))\n",
    "    target_positions = transform_points(data_ego[\"target_positions\"], data_ego[\"raster_from_agent\"])\n",
    "    draw_trajectory(im_ego, target_positions, TARGET_POINTS_COLOR)\n",
    "    plt.imshow(im_ego)\n",
    "    plt.axis('off')\n",
    "    plt.show()\n",
    "\n",
    "# before leaving, ensure perturb_prob is correct\n",
    "perturbation.perturb_prob = perturb_prob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = RasterizedPlanningModel(\n",
    "        model_arch=\"resnet50\",\n",
    "        num_input_channels=rasterizer.num_channels(),\n",
    "        num_targets=3 * cfg[\"model_params\"][\"future_num_frames\"],  # X, Y, Yaw * number of future states,\n",
    "        weights_scaling= [1., 1., 1.],\n",
    "        criterion=nn.MSELoss(reduction=\"none\")\n",
    "        )\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare for training\n",
    "Our `EgoDataset` inherits from PyTorch `Dataset`; so we can use it inside a `Dataloader` to enable multi-processing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_cfg = cfg[\"train_data_loader\"]\n",
    "train_dataloader = DataLoader(train_dataset, shuffle=train_cfg[\"shuffle\"], batch_size=train_cfg[\"batch_size\"], \n",
    "                             num_workers=train_cfg[\"num_workers\"])\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = model.to(device)\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "print(train_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training loop\n",
    "Here, we purposely include a barebone training loop. Clearly, many more components can be added to enrich logging and improve performance. Still, the sheer size of our dataset ensures that a reasonable performance can be obtained even with this simple loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tr_it = iter(train_dataloader)\n",
    "progress_bar = tqdm(range(cfg[\"train_params\"][\"max_num_steps\"]))\n",
    "losses_train = []\n",
    "model.train()\n",
    "torch.set_grad_enabled(True)\n",
    "\n",
    "for _ in progress_bar:\n",
    "    try:\n",
    "        data = next(tr_it)\n",
    "    except StopIteration:\n",
    "        tr_it = iter(train_dataloader)\n",
    "        data = next(tr_it)\n",
    "    # Forward pass\n",
    "    data = {k: v.to(device) for k, v in data.items()}\n",
    "    result = model(data)\n",
    "    loss = result[\"loss\"]\n",
    "    # Backward pass\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    losses_train.append(loss.item())\n",
    "    progress_bar.set_description(f\"loss: {loss.item()} loss(avg): {np.mean(losses_train)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot the train loss curve\n",
    "We can plot the train loss against the iterations (batch-wise) to check if our model has converged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(np.arange(len(losses_train)), losses_train, label=\"train loss\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Store the model\n",
    "\n",
    "Let's store the model as a torchscript. This format allows us to re-load the model and weights without requiring the class definition later.\n",
    "\n",
    "**Take note of the path, you will use it later to evaluate your planning model!**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "to_save = torch.jit.script(model.cpu())\n",
    "path_to_save = f\"{gettempdir()}/planning_model.pt\"\n",
    "to_save.save(path_to_save)\n",
    "print(f\"MODEL STORED at {path_to_save}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations in training your first ML policy for planning!\n",
    "### What's Next\n",
    "\n",
    "Now that your model is trained and safely stored, you can evaluate how it performs in two very different situations using our dedicated notebooks:\n",
    "\n",
    "### [Open-loop evaluation](./open_loop_test.ipynb)\n",
    "In this setting the model **is not controlling the AV**, and predictions are used to compute metrics only.\n",
    "\n",
    "### [Closed-loop evaluation](./closed_loop_test.ipynb)\n",
    "In this setting the model **is in full control of the AV** future movements.\n",
    "\n",
    "## Pre-trained models\n",
    "we provide a collection of pre-trained models for the planning task:\n",
    "- [model](https://d20lyvjneielsk.cloudfront.net/planning_model_20201208.pt) trained on  train.zarr for 15 epochs;\n",
    "- [model](https://d20lyvjneielsk.cloudfront.net/planning_model_20201208_early.pt) trained on train.zarr for 2 epochs;\n",
    "- [model](https://d20lyvjneielsk.cloudfront.net/planning_model_20201208_nopt.pt) trained on train.zarr with perturbations disabled for 15 epochs;\n",
    "- [model](https://d20lyvjneielsk.cloudfront.net/planning_model_20201208_nopt_early.pt) trained on train.zarr with perturbations disabled for 2 epochs;\n",
    "\n",
    "We include two partially trained models to emphasise the important role of perturbations during training, especially during the first stage of training.\n",
    "\n",
    "To use one of the models simply download the corresponding `.pt` file and load it in the evaluation notebooks."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
